import torch
import os
from torch.utils.data import Dataset, DataLoader
import glob
from sklearn.model_selection import train_test_split
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
from generate_T import find_invertible_submatrix, Generator_matrix
from Random_T import Random_matrix



# Seed NumPy's and Python's global RNGs at import time so that any sampling
# drawing on them is repeatable from run to run.
# NOTE(review): torch's RNG is not seeded by this block; add torch.manual_seed
# if model initialisation or DataLoader shuffling must be reproducible too.
np.random.seed(2024)
random.seed(2024)
class CustomDataset(Dataset):
    """Map-style dataset that pairs each input sequence with its label.

    Items are returned as ``(sequence, label)`` tuples taken from the same
    index of the two containers given at construction time. The dataset
    length is the length of ``labels``.
    """

    def __init__(self, sequences1, labels):
        # Both containers are stored as-is (no copy, no conversion).
        self.sequences1, self.labels = sequences1, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return self.sequences1[index], self.labels[index]

def give_batch(path, test_size=0.2, random_state=42):
    """Load a tab-separated sequence/label file and split it into train/test sets.

    Each non-blank line must look like ``"<v1>,<v2>,...\\t<label>"``. The
    first tab-separated field is a comma-separated list of integers and the
    second is an integer label; any further tab-separated fields are ignored.
    Empty items in the integer list (e.g. from a trailing comma) are read as 0.
    Blank lines are skipped.

    Args:
        path: Path of the UTF-8 text file to read.
        test_size: Fraction of samples placed in the test split (default 0.2).
        random_state: Seed forwarded to ``train_test_split`` so the shuffled
            split is reproducible (default 42).

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test)`` produced by
        ``train_test_split`` with ``shuffle=True``.
    """
    sequences = []
    labels = []
    with open(path, 'r', encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # A blank line (e.g. an extra trailing newline) holds no
                # sample; previously it raised IndexError on the label lookup.
                continue
            fields = stripped.split('\t')
            labels.append(int(fields[1]))
            # Empty items map to 0. This is what the former (dead) ``dd`` line
            # intended; before, int('') raised ValueError first.
            sequences.append(
                [int(tok) if tok.strip() else 0 for tok in fields[0].strip().split(',')]
            )
    X_train, X_test, y_train, y_test = train_test_split(
        sequences, labels, test_size=test_size, random_state=random_state, shuffle=True
    )
    return X_train, X_test, y_train, y_test

def _minmax_normalize(values):
    """Linearly rescale ``values`` to [0, 1]; a constant array maps to all zeros."""
    values = np.asarray(values, dtype=np.float64)
    low = np.min(values)
    span = np.max(values) - low
    if span == 0:
        # Dividing by a zero range would produce NaN and poison any running sum.
        return np.zeros_like(values)
    return (values - low) / span


def heatmapVisual(model, dataloader, data_length, device, log_dir):
    """Accumulate per-sample normalised ``activate(E @ E^T)`` matrices and plot them.

    For every batch ``x`` from ``dataloader`` this computes
    ``E = model.embedding1(x)`` and ``A = E @ E^T`` (matmul over the last two
    axes), then applies ``model.activate``. Each sample's matrix is min-max
    normalised to [0, 1] and added to a ``(data_length, data_length)``
    accumulator.

    After all batches have been consumed, it writes the following into ``log_dir``:
      * ``accumulated_matrix.txt``: the raw accumulator (``np.savetxt``);
      * ``accumulated_attention_matrix_heatmap.png``: a heatmap of it;
      * ``column_summed_normalized_heatmap.png``: its column sums, min-max
        normalised, drawn as a single annotated row.

    If ``dataloader`` yields no batches, nothing is written.

    Args:
        model: Object exposing ``embedding1`` and ``activate`` callables.
        dataloader: Iterable of ``(x, y)`` pairs; only ``x`` is used.
        data_length: Side length of the accumulated square matrix. Each
            sample's activated matrix must broadcast to this shape.
        device: Device that ``x`` is moved to before the forward pass.
        log_dir: Existing directory that receives the output files.
    """
    accumulated_matrix = np.zeros((data_length, data_length))
    saw_batch = False
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for x, _ in dataloader:
            saw_batch = True
            x = x.long().to(device)
            E = model.embedding1(x)
            A = torch.matmul(E, E.transpose(-2, -1))
            attention_matrix = model.activate(A).detach().cpu().numpy()
            for current_matrix in attention_matrix:
                accumulated_matrix += _minmax_normalize(current_matrix)

    if not saw_batch:
        # An empty loader previously produced no output files; keep it that way.
        return

    # Outputs are written once, after accumulation. Previously this ran inside
    # the batch loop and re-rendered and overwrote every file on each batch.
    np.savetxt(os.path.join(log_dir, 'accumulated_matrix.txt'), accumulated_matrix)

    # NOTE(review): A = E @ E^T is square over one axis of E, so both heatmap
    # axes index the same dimension; confirm the axis labels below are intended.
    plt.figure(figsize=(10, 8))
    sns.heatmap(accumulated_matrix, cmap='viridis')
    plt.title("Accumulated Normalized Attention Matrix Heatmap")
    plt.xlabel("Embedding Dimension")
    plt.ylabel("Sequence Length")
    plt.savefig(os.path.join(log_dir, 'accumulated_attention_matrix_heatmap.png'), bbox_inches='tight')
    plt.close()

    normalized_column_sum = _minmax_normalize(np.sum(accumulated_matrix, axis=0))
    heatmap_row = np.expand_dims(normalized_column_sum, axis=0)
    plt.figure(figsize=(12, 2))
    sns.heatmap(heatmap_row, cmap='viridis', cbar=True, annot=True)
    plt.title("Column-wise Summed and Normalized Feature Heatmap")
    plt.xlabel("Embedding Dimension")
    plt.yticks([])
    plt.savefig(os.path.join(log_dir, 'column_summed_normalized_heatmap.png'), bbox_inches='tight')
    plt.close()

def getClassificationMetrics(model, dataloader, device, log_dir, prefix):
    """Evaluate a binary classifier over ``dataloader`` and return summary metrics.

    Switches ``model`` to eval mode (and leaves it there) and runs it without
    gradient tracking. The prediction is the arg-max over dim 1 of the model
    output. The ROC score is the softmax probability of class 1.

    Args:
        model: Callable mapping a LongTensor batch to per-class scores along
            dim 1; class 1 is treated as the positive class.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device that the inputs are moved to.
        log_dir: Unused; kept for interface compatibility.
        prefix: Unused; kept for interface compatibility.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, specificity, auc, conf_matrix)``:
          * ``conf_matrix`` is the 2x2 matrix for labels ``[0, 1]``, with rows
            for true labels and columns for predicted labels.
          * Precision, recall and F1 use ``zero_division=1``. Specificity is
            likewise 1.0 when there are no true negatives or false positives.
          * ``auc`` is NaN when the true labels contain fewer than two
            distinct classes, because ROC AUC is undefined then.
    """
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for x_test, y_test in dataloader:
            outputs = model(x_test.long().to(device))
            all_probs.extend(outputs.softmax(dim=1)[:, 1].cpu().numpy())
            _, predictions = torch.max(outputs, 1)
            all_preds.extend(predictions.cpu().numpy())
            # Labels are only needed on the host, so skip the device round trip.
            all_labels.extend(y_test.long().cpu().numpy())
    all_probs = np.array(all_probs)

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=1, pos_label=1)
    recall = recall_score(all_labels, all_preds, zero_division=1, pos_label=1)
    f1 = f1_score(all_labels, all_preds, zero_division=1, pos_label=1)

    # Pin the label set so the matrix is always 2x2. Without it, a run that
    # only sees one class gives a 1x1 matrix and the indexing below fails.
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    TN = conf_matrix[0, 0]
    FP = conf_matrix[0, 1]
    negatives = TN + FP
    specificity = TN / negatives if negatives else 1.0

    if len(np.unique(all_labels)) < 2:
        # roc_auc_score raises ValueError when only one class is present.
        auc = float('nan')
    else:
        auc = roc_auc_score(all_labels, all_probs)
    return accuracy, precision, recall, f1, specificity, auc, conf_matrix

def plot_confusion_matrix(conf_matrix, log_dir, prefix):
    """Render ``conf_matrix`` as an annotated heatmap and save it as a PNG.

    Cells are annotated with integer counts in the 'Blues' colormap. The
    x axis is labelled as predicted labels and the y axis as true labels.
    The image goes to ``<log_dir>/<prefix>_confusion_matrix.png`` and the
    figure is closed afterwards.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    fig.savefig(os.path.join(log_dir, f"{prefix}_confusion_matrix.png"))
    plt.close(fig)
    
def plot_accuracy(train_accuracy_list, test_accuracy_list, epochs, log_dir):
    """Plot train and test accuracy per epoch and save the chart as a PNG.

    Both series are drawn against epochs ``1..epochs``, so each list must
    hold ``epochs`` values. The chart has a legend and a grid, is written to
    ``<log_dir>/accuracy_plot.png``, and its figure is closed afterwards.
    """
    epoch_axis = range(1, epochs + 1)
    fig, ax = plt.subplots(figsize=(10, 6))
    series = (
        (train_accuracy_list, 'Train Accuracy'),
        (test_accuracy_list, 'Test Accuracy'),
    )
    for values, series_label in series:
        ax.plot(epoch_axis, values, label=series_label)
    ax.set_title('Accuracy over Epochs')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Accuracy')
    ax.legend()
    ax.grid(True)
    fig.savefig(os.path.join(log_dir, 'accuracy_plot.png'))
    plt.close(fig)